Skip to content

[clang:frontend] Move helper functions to common location for SemaSPIRV #125045

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed

[clang:frontend] Move helper functions to common location for SemaSPIRV #125045

wants to merge 5 commits into from

Conversation

bassiounix
Copy link
Contributor

Move helper functions out of clang/lib/Sema/SemaHLSL.cpp into a common location for clang/lib/Sema/SemaSPIRV.cpp to use.

Moved functions are CheckArgTypeIsCorrect and CheckAllArgTypesAreCorrect.

This is a contribution to the issue #123831.

Copy link

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" HLSL HLSL Language Support backend:SPIR-V labels Jan 30, 2025
@llvmbot
Copy link
Member

llvmbot commented Jan 30, 2025

@llvm/pr-subscribers-backend-spir-v
@llvm/pr-subscribers-hlsl

@llvm/pr-subscribers-clang

Author: Muhammad Bassiouni (bassiounix)

Changes

Move helper functions out of clang/lib/Sema/SemaHLSL.cpp into a common location for clang/lib/Sema/SemaSPIRV.cpp to use.

Moved functions are CheckArgTypeIsCorrect and CheckAllArgTypesAreCorrect.

This is a contribution to the issue #123831.


Full diff: https://github.com/llvm/llvm-project/pull/125045.diff

5 Files Affected:

  • (added) clang/include/clang/Sema/Common.h (+22)
  • (modified) clang/lib/Sema/CMakeLists.txt (+1)
  • (added) clang/lib/Sema/Common.cpp (+65)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+1-27)
  • (modified) clang/lib/Sema/SemaSPIRV.cpp (+6-46)
diff --git a/clang/include/clang/Sema/Common.h b/clang/include/clang/Sema/Common.h
new file mode 100644
index 00000000000000..3f775df8bddb64
--- /dev/null
+++ b/clang/include/clang/Sema/Common.h
@@ -0,0 +1,22 @@
+#ifndef LLVM_CLANG_SEMA_COMMON_H
+#define LLVM_CLANG_SEMA_COMMON_H
+
+#include "clang/Sema/Sema.h"
+
+namespace clang {
+
+using LLVMFnRef = llvm::function_ref<bool(clang::QualType PassedType)>;
+using PairParam = std::pair<unsigned int, unsigned int>;
+using CheckParam = std::variant<PairParam, LLVMFnRef>;
+
+bool CheckArgTypeIsCorrect(
+    Sema *S, Expr *Arg, QualType ExpectedType,
+    llvm::function_ref<bool(clang::QualType PassedType)> Check);
+
+bool CheckAllArgTypesAreCorrect(
+    Sema *SemaPtr, CallExpr *TheCall,
+    std::variant<QualType, std::nullopt_t> ExpectedType, CheckParam Check);
+
+} // namespace clang
+
+#endif
diff --git a/clang/lib/Sema/CMakeLists.txt b/clang/lib/Sema/CMakeLists.txt
index 19cf3a2db00fdc..ddc340a51a3b2d 100644
--- a/clang/lib/Sema/CMakeLists.txt
+++ b/clang/lib/Sema/CMakeLists.txt
@@ -17,6 +17,7 @@ add_clang_library(clangSema
   AnalysisBasedWarnings.cpp
   CheckExprLifetime.cpp
   CodeCompleteConsumer.cpp
+  Common.cpp
   DeclSpec.cpp
   DelayedDiagnostic.cpp
   HeuristicResolver.cpp
diff --git a/clang/lib/Sema/Common.cpp b/clang/lib/Sema/Common.cpp
new file mode 100644
index 00000000000000..72a9e4a2c99ae1
--- /dev/null
+++ b/clang/lib/Sema/Common.cpp
@@ -0,0 +1,65 @@
+#include "clang/Sema/Common.h"
+
+namespace clang {
+
+bool CheckArgTypeIsCorrect(
+    Sema *S, Expr *Arg, QualType ExpectedType,
+    llvm::function_ref<bool(clang::QualType PassedType)> Check) {
+  QualType PassedType = Arg->getType();
+  if (Check(PassedType)) {
+    if (auto *VecTyA = PassedType->getAs<VectorType>())
+      ExpectedType = S->Context.getVectorType(
+          ExpectedType, VecTyA->getNumElements(), VecTyA->getVectorKind());
+    S->Diag(Arg->getBeginLoc(), diag::err_typecheck_convert_incompatible)
+        << PassedType << ExpectedType << 1 << 0 << 0;
+    return true;
+  }
+  return false;
+}
+
+bool CheckAllArgTypesAreCorrect(
+    Sema *SemaPtr, CallExpr *TheCall,
+    std::variant<QualType, std::nullopt_t> ExpectedType, CheckParam Check) {
+  unsigned int NumElts;
+  unsigned int expected;
+  if (auto *n = std::get_if<PairParam>(&Check)) {
+    if (SemaPtr->checkArgCount(TheCall, n->first)) {
+      return true;
+    }
+    NumElts = n->first;
+    expected = n->second;
+  } else {
+    NumElts = TheCall->getNumArgs();
+  }
+
+  for (unsigned i = 0; i < NumElts; i++) {
+    Expr *localArg = TheCall->getArg(i);
+    if (auto *val = std::get_if<QualType>(&ExpectedType)) {
+      if (auto *fn = std::get_if<LLVMFnRef>(&Check)) {
+        return CheckArgTypeIsCorrect(SemaPtr, localArg, *val, *fn);
+      }
+    }
+
+    QualType PassedType = localArg->getType();
+    if (PassedType->getAs<VectorType>() == nullptr) {
+      SemaPtr->Diag(localArg->getBeginLoc(),
+                    diag::err_typecheck_convert_incompatible)
+          << PassedType
+          << SemaPtr->Context.getVectorType(PassedType, expected,
+                                            VectorKind::Generic)
+          << 1 << 0 << 0;
+      return true;
+    }
+  }
+
+  if (std::get_if<PairParam>(&Check)) {
+    if (auto *localArgVecTy =
+            TheCall->getArg(0)->getType()->getAs<VectorType>()) {
+      TheCall->setType(localArgVecTy->getElementType());
+    }
+  }
+
+  return false;
+}
+
+} // namespace clang
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index d748c10455289b..0cc71e4122666c 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -27,6 +27,7 @@
 #include "clang/Basic/SourceLocation.h"
 #include "clang/Basic/Specifiers.h"
 #include "clang/Basic/TargetInfo.h"
+#include "clang/Sema/Common.h"
 #include "clang/Sema/Initialization.h"
 #include "clang/Sema/ParsedAttr.h"
 #include "clang/Sema/Sema.h"
@@ -1996,33 +1997,6 @@ static bool CheckArgTypeMatches(Sema *S, Expr *Arg, QualType ExpectedType) {
   return false;
 }
 
-static bool CheckArgTypeIsCorrect(
-    Sema *S, Expr *Arg, QualType ExpectedType,
-    llvm::function_ref<bool(clang::QualType PassedType)> Check) {
-  QualType PassedType = Arg->getType();
-  if (Check(PassedType)) {
-    if (auto *VecTyA = PassedType->getAs<VectorType>())
-      ExpectedType = S->Context.getVectorType(
-          ExpectedType, VecTyA->getNumElements(), VecTyA->getVectorKind());
-    S->Diag(Arg->getBeginLoc(), diag::err_typecheck_convert_incompatible)
-        << PassedType << ExpectedType << 1 << 0 << 0;
-    return true;
-  }
-  return false;
-}
-
-static bool CheckAllArgTypesAreCorrect(
-    Sema *S, CallExpr *TheCall, QualType ExpectedType,
-    llvm::function_ref<bool(clang::QualType PassedType)> Check) {
-  for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) {
-    Expr *Arg = TheCall->getArg(i);
-    if (CheckArgTypeIsCorrect(S, Arg, ExpectedType, Check)) {
-      return true;
-    }
-  }
-  return false;
-}
-
 static bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) {
   auto checkAllFloatTypes = [](clang::QualType PassedType) -> bool {
     return !PassedType->hasFloatingRepresentation();
diff --git a/clang/lib/Sema/SemaSPIRV.cpp b/clang/lib/Sema/SemaSPIRV.cpp
index dc49fc79073572..df6a3d61056f5e 100644
--- a/clang/lib/Sema/SemaSPIRV.cpp
+++ b/clang/lib/Sema/SemaSPIRV.cpp
@@ -10,7 +10,9 @@
 
 #include "clang/Sema/SemaSPIRV.h"
 #include "clang/Basic/TargetBuiltins.h"
+#include "clang/Sema/Common.h"
 #include "clang/Sema/Sema.h"
+#include <utility>
 
 namespace clang {
 
@@ -20,54 +22,12 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID,
                                               CallExpr *TheCall) {
   switch (BuiltinID) {
   case SPIRV::BI__builtin_spirv_distance: {
-    if (SemaRef.checkArgCount(TheCall, 2))
-      return true;
-
-    ExprResult A = TheCall->getArg(0);
-    QualType ArgTyA = A.get()->getType();
-    auto *VTyA = ArgTyA->getAs<VectorType>();
-    if (VTyA == nullptr) {
-      SemaRef.Diag(A.get()->getBeginLoc(),
-                   diag::err_typecheck_convert_incompatible)
-          << ArgTyA
-          << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
-          << 0 << 0;
-      return true;
-    }
-
-    ExprResult B = TheCall->getArg(1);
-    QualType ArgTyB = B.get()->getType();
-    auto *VTyB = ArgTyB->getAs<VectorType>();
-    if (VTyB == nullptr) {
-      SemaRef.Diag(A.get()->getBeginLoc(),
-                   diag::err_typecheck_convert_incompatible)
-          << ArgTyB
-          << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
-          << 0 << 0;
-      return true;
-    }
-
-    QualType RetTy = VTyA->getElementType();
-    TheCall->setType(RetTy);
-    break;
+    return CheckAllArgTypesAreCorrect(&SemaRef, TheCall, std::nullopt,
+                                      std::make_pair(2, 2));
   }
   case SPIRV::BI__builtin_spirv_length: {
-    if (SemaRef.checkArgCount(TheCall, 1))
-      return true;
-    ExprResult A = TheCall->getArg(0);
-    QualType ArgTyA = A.get()->getType();
-    auto *VTy = ArgTyA->getAs<VectorType>();
-    if (VTy == nullptr) {
-      SemaRef.Diag(A.get()->getBeginLoc(),
-                   diag::err_typecheck_convert_incompatible)
-          << ArgTyA
-          << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
-          << 0 << 0;
-      return true;
-    }
-    QualType RetTy = VTy->getElementType();
-    TheCall->setType(RetTy);
-    break;
+    return CheckAllArgTypesAreCorrect(&SemaRef, TheCall, std::nullopt,
+                                      std::make_pair(1, 2));
   }
   }
   return false;

Copy link
Collaborator

@DavidSpickett DavidSpickett left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a drive by style comment.

@@ -0,0 +1,65 @@
#include "clang/Sema/Common.h"

namespace clang {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://llvm.org/docs/CodingStandards.html#use-namespace-qualifiers-to-implement-previously-declared-functions

For example clang::CheckArgTypeIsCorrect instead of opening the clang namespace and just having CheckArgTypeIsCorrect.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

Copy link
Member

@junlarsen junlarsen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Each file should have license headers, ref https://llvm.org/docs/CodingStandards.html#file-headers

Copy link
Member

@Sirraide Sirraide left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there may be some code duplication here, so moving some of this code around might make sense, but I see a few issues with how this is done at the moment:

  1. If we were to introduce new helpers, then instead of making a new header for this, these functions should just become member functions of the Sema class.

  2. The new version of CheckAllArgTypesAreCorrect() is... way too complicated, in my opinion; I have trouble trying to figure out what two parameters that are a variant and a pair are supposed to mean from looking at the function, and it does feel like this function wants to be at least two or three separate functions...

I think a better approach would be to leave the HLSL functions alone and just factor out a separate helper and put it in SemaSPIRV.cpp, because as-is, this pr isn’t exactly simplifying things...

@bassiounix
Copy link
Contributor Author

bassiounix commented Jan 31, 2025

  1. If we were to introduce new helpers, then instead of making a new header for this, these functions should just become member functions of the Sema class.

Sema or SemaBase?

  1. The new version of CheckAllArgTypesAreCorrect() is... way too complicated, in my opinion; I have trouble trying to figure out what two parameters that are a variant and a pair are supposed to mean from looking at the function, and it does feel like this function wants to be at least two or three separate functions...

I'll split the implementation in 2 functions.

I think a better approach would be to leave the HLSL functions alone and just factor out a separate helper and put it in SemaSPIRV.cpp, because as-is, this pr isn’t exactly simplifying things...

Since there is a common functionalities, I think I'll go with the first approach, that is, declaring them in a common class.
This class is Sema for now as mentioned in the review.

@Sirraide
Copy link
Member

Sema or SemaBase?

Typically we just put those in Sema; I think the only thing in SemaBase at the moment are some diagnostics helpers.

@bassiounix bassiounix requested a review from Sirraide January 31, 2025 07:26
@bassiounix
Copy link
Contributor Author

Huh.. this shouldn't fail! I tested locally before pushing the commits!

Looks like these files aren't mine .. I guess ..

@Sirraide
Copy link
Member

Well, I’ve encountered a bunch of failing tests lately, but the fact that SPIRV and HLSL tests are failing is a bit suspicious

@Sirraide
Copy link
Member

Yup, looking at the logs, there’s definitely an error somewhere in this pr (don’t really have the time to take a closer look at it at the moment unfortunately)

Copy link
Collaborator

@AaronBallman AaronBallman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for trying to take this on, we definitely appreciate when folks try to split things up to improve our ability to maintain the code base!

That said, I don't think we need another file for common functions; Sema.cpp is already where I would expect those to live. If the functionality is only common to HLSL and SPIRV, maybe we want to consider a SemaOffload.cpp or something along those lines? But it's less clear to me how tightly we want to couple the relationships between all the various offload languages. I think this requires a broader design discussion with the impacted offloading maintainers. CC @llvm-beanz @jdoerfert @bader @alexey-bataev

(It's unfortunate that the issue was marked good first issue; this has quite a bit of layering considerations that we've not really thought out yet.)

@bassiounix bassiounix closed this by deleting the head repository Mar 7, 2025
@damyanp damyanp moved this to Closed in HLSL Support Apr 25, 2025
@damyanp damyanp removed this from HLSL Support Jun 25, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:SPIR-V clang:frontend Language frontend issues, e.g. anything involving "Sema" clang Clang issues not falling into any other category HLSL HLSL Language Support
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants